
include("../src/qmcmary.jl")
using ..qmcmary

using Test
using LinearAlgebra


"""
    rawgf(allbmats, idx)

Compute the equal-time Green's function at time slice `idx` by brute force (reference
implementation for tests, no numerical stabilization):

    G(idx, idx) = [1 + B[idx]⋯B[1] · B[Nt]⋯B[idx+1]]⁻¹

with `B[i] = allbmats[i]` and `Nt = length(allbmats)`. Both `idx == 0` and `idx == Nt`
give `[1 + B[Nt]⋯B[1]]⁻¹`. All matrices are assumed square and of the same size.
"""
function rawgf(allbmats, idx)
    Nt = length(allbmats)
    n = size(allbmats[1], 1)
    # Common element type so the accumulators keep one concrete type (type stability).
    T = mapreduce(eltype, promote_type, allbmats)
    # Right part R = B[idx]⋯B[1], built by left-multiplication.
    Rs = Matrix{T}(I, n, n)
    for i in 1:idx
        Rs = allbmats[i] * Rs
    end
    # Left part L = B[Nt]⋯B[idx+1], built by right-multiplication.
    Ls = Matrix{T}(I, n, n)
    for i in Nt:-1:(idx + 1)
        Ls = Ls * allbmats[i]
    end
    return inv(I + Rs * Ls)
end

"""
    rawgft(allbmats, idx)

Compute the time-displaced Green's functions between slice `idx` and slice 0 by brute
force (reference implementation for tests). With `G(0,0) = rawgf(allbmats, 0)` and
`B(t,0) = allbmats[idx]⋯allbmats[1]` (the identity when `idx <= 0`):

- `g(t,0) =  <c(t) c⁺(0)> =  B(t,0) G(0,0)`
- `g(0,t) = -<c⁺(t) c(0)> = -(1 - G(0,0)) B(t,0)⁻¹`

Returns the tuple `(g0t, gt0)`.
"""
function rawgft(allbmats, idx)
    g0 = rawgf(allbmats, 0)
    n = size(g0, 1)
    # B(t,0) = B[idx]⋯B[1]; starting from the identity covers idx == 0 without a branch
    # and keeps bt0 a concrete Matrix.
    bt0 = Matrix{eltype(g0)}(I, n, n)
    for i in 1:idx
        bt0 = allbmats[i] * bt0
    end
    gt0 = bt0 * g0
    # Right division solves X * B(t,0) = -(1 - G(0,0)) without forming inv(B(t,0)).
    g0t = -(I - g0) / bt0
    return g0t, gt0
end


"""
    local_update_test()

Randomized consistency check of the fast local-update formulas in `qmcmary` against
brute-force recomputation, on a small real-valued model (`Nx = 2`, `Nt = 20` slices fed
to a `ScrollSVD` in groups of `Ng = 5`).

Checks, in order:
1. The phase from `eq_green_scratch` equals `det(1 + B[Nt]⋯B[1]) / |det(…)|`.
2. After `B[Nt-1] -> (1 + Δ) B[Nt-1]` with a random diagonal `Δ`,
   `ShermanMorrison(Δ, G)` matches `rawgf(allbmats, Nt-1)` and `Woodbury(Δ, G)`
   matches the brute-force determinant ratio.
3. After a second perturbation by a dense random `Δ` acting only on the matrix indices
   `idx = [1, 3]`, `Woodbury(idx, Δ, G)` and `PressShermanMorrison(idx, Δ, G)` match
   the brute-force determinant ratio and Green's function.

The RNG is not seeded, so every call tests a different random configuration.
"""
function local_update_test()
    ehk = [0 1; 1 0]
    Nx = 2
    Nt = 20
    Ng = 5
    sslen = Int(Nt // Ng)      # throws InexactError if Ng does not divide Nt
    bmats = Vector{Matrix{Float64}}(undef, Ng)
    allbmats = Vector{Matrix{Float64}}(undef, Nt)

    # Fill group `si` (0-based) with Ng slices of random Ising spins ±1; returns `bmats`.
    function fill_group!(si)
        for bi in Base.OneTo(Ng)
            cfg = 2(round.(rand(Nx)) .- 0.5)
            bmats[bi] = bmat_Ising(ehk, cfg)
            allbmats[bi + si * Ng] = bmats[bi]
        end
        return bmats
    end

    # The first group creates the scrolling SVD stack; the remaining groups are pushed.
    ss = ScrollSVD(fill_group!(0))
    for si in 1:(sslen - 1)
        push!(ss, fill_group!(si))
    end

    n = size(allbmats[1], 1)   # dimension of every B matrix
    Id = Diagonal(ones(n))
    tau = Nt - 1               # slice that is perturbed below

    # Reference state before any perturbation.
    gf, ph = eq_green_scratch(ss)
    # Propagate past B[Nt]; the rawgf(allbmats, tau) comparisons below assume gf is
    # then the Green's function at slice `tau` (NOTE(review): confirm for down_prog_green).
    gf = qmcmary.down_prog_green(gf, allbmats[Nt])
    dpst = det(Id + prod(reverse(allbmats)))
    pht = dpst / abs(dpst)
    @test isapprox(ph, pht)

    # Random diagonal perturbation of B[tau]. (The plain rank-one Sherman–Morrison
    # formula G - GΔ(1-G)/(1 + tr(Δ(1-G))) only covers a single nonzero entry of Δ.)
    dmat = Diagonal(rand(n) .- 0.5)
    allbmats[tau] = (Id + dmat) * allbmats[tau]
    newgf = qmcmary.rawgf(allbmats, tau)
    println(dmat)
    gfupdated = qmcmary.ShermanMorrison(dmat, gf)
    println(newgf)
    println(gfupdated)
    dnow1 = det(Id + prod(reverse(allbmats)))
    dnow2 = qmcmary.Woodbury(dmat, gf)
    println(dnow1 / dpst, " ", dnow2)
    @test all(isapprox.(newgf, gfupdated, atol=1e-10))
    @test isapprox(dnow1 / dpst, dnow2, atol=1e-10)

    # Dense perturbation restricted to rows/columns `idx`: tests PressShermanMorrison
    # and the off-diagonal Woodbury.
    idx = [1, 3]
    dpst = dnow1
    dmat = rand(Float64, length(idx), length(idx)) .- 0.5
    dmatall = zeros(Float64, n, n)
    dmatall[idx, idx] = dmat   # scatter: dmatall[idx[x], idx[y]] = dmat[x, y]
    allbmats[tau] = (Id + dmatall) * allbmats[tau]
    newgf = qmcmary.rawgf(allbmats, tau)
    dnow1 = det(Id + prod(reverse(allbmats)))
    dnow2 = qmcmary.Woodbury(idx, dmat, gfupdated)
    println(dnow1 / dpst, " ", dnow2)
    @test isapprox(dnow1 / dpst, dnow2)
    gfupdated2 = qmcmary.PressShermanMorrison(idx, dmat, gfupdated)
    println(gfupdated2, newgf)
    # Elementwise isapprox needs an absolute tolerance: with the default atol = 0 it
    # fails spuriously on entries that are numerically zero.
    @test all(isapprox.(newgf, gfupdated2, atol=1e-10))
    return
end



"""
    example()

Finite-difference sanity check of the gradient measured by `meas_grad` for the
`bmat_IsingND` model. Prints results and asserts nothing.

Builds `Nt = 20` slices (`Nx = 3` sites, random Hermitian `hk`, random Ising spins ±1)
twice from the same spin configuration: once with the unit vectors
`(nx, ny, nz) = (0, sin θ, cos θ)` and once with `nz[ipert]` shifted by `delta`.
It then prints

- `log det(1 + B[Nt]⋯B[1]) - log det(1 + B'[Nt]⋯B'[1])`, the finite difference between
  the unperturbed chain `B` and the perturbed chain `B'`;
- the gradients `nybar`, `nzbar` returned by `meas_grad`;
- `thetabar = nybar cos θ - nzbar sin θ`, the chain rule for `∂/∂θ` given
  `ny = sin θ`, `nz = cos θ`.

NOTE(review): the first line should be compared with `±delta * nzbar[ipert]`; the sign
depends on the convention of `meas_grad` — confirm.
"""
function example()
    Nx = 3
    Nt = 20
    Ng = 5
    sslen = Int(Nt // Ng)      # throws InexactError if Ng does not divide Nt
    # Step used both in exp(-dt * hk) and as the last argument of bmat_IsingND
    # (presumably the imaginary-time step).
    dt = 0.1
    # Perturbation applied to the second chain: nz[ipert] += delta.
    ipert = 3
    delta = 0.01

    # Random Hermitian hk; ehk = exp(-dt * hk) repeated in two diagonal blocks
    # (2Nx × 2Nx, presumably one block per spin).
    hk = rand(ComplexF64, Nx, Nx)
    hk += adjoint(hk)
    ehk = kron([1 0; 0 1], exp(-dt * hk))
    Ui = ones(Nx)

    # Unit vectors (nx, ny, nz) = (0, sin θ, cos θ) per site.
    # Note: if (nx, ny, nz) does not have unit length, expV goes wrong and
    # deltamat no longer equals exp2V.
    the = rand(Nx) * pi
    nx = zeros(Nx)
    ny = sin.(the)
    nz = cos.(the)
    # NOTE(review): nz2 is deliberately no longer unit length (see note above);
    # confirm bmat_IsingND is still adequate for this finite difference.
    ny2 = copy(ny)
    nz2 = copy(nz)
    nz2[ipert] += delta

    # Both chains share the same random spin configuration, slice by slice.
    allbmats = Vector{Matrix{ComplexF64}}(undef, Nt)
    allbmats2 = Vector{Matrix{ComplexF64}}(undef, Nt)
    hscfg = Matrix{Int}(undef, Nt, Nx)
    for tau in 1:Nt
        cfg = 2(round.(rand(Nx)) .- 0.5)
        hscfg[tau, :] .= cfg
        allbmats[tau] = bmat_IsingND(ehk, cfg, Ui, nx, ny, nz, dt)
        allbmats2[tau] = bmat_IsingND(ehk, cfg, Ui, nx, ny2, nz2, dt)
    end

    # Scrolling SVD stack over the unperturbed chain, fed in groups of Ng slices.
    ss = ScrollSVD(allbmats[1:Ng])
    for si in 1:(sslen - 1)
        push!(ss, allbmats[(si * Ng + 1):((si + 1) * Ng)])
    end

    sp = default_splitting(Nt, hk, Ui)
    tapemap = pbpn_calc_tapemap(sp, nx, ny, nz)
    _, nybar, nzbar = meas_grad(ss, allbmats, sp, hscfg, nx, ny, nz;
                                tapemap=tapemap)
    Id = Diagonal(ones(2 * Nx))
    det1 = det(Id + prod(reverse(allbmats)))
    det2 = det(Id + prod(reverse(allbmats2)))
    println(-log(det2) + log(det1))
    println(nybar, nzbar)
    # Chain rule: ny = sin θ, nz = cos θ  =>  ∂/∂θ = cos θ ∂/∂ny - sin θ ∂/∂nz.
    thetabar = @. nybar * cos(the) - nzbar * sin(the)
    println(thetabar)
    return
end


# Script entry point: run the gradient finite-difference check.
# `local_update_test()` draws fresh random configurations on every call; to stress it,
# run e.g. `foreach(_ -> local_update_test(), 1:100)`.
example()


